{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "06_GAN_double_moon_empty_colab.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.6"
    },
    "kernelspec": {
      "display_name": "Python [conda env:pytorch]",
      "language": "python",
      "name": "conda-env-pytorch-py"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mlelarge/dataflowr/blob/master/CEA_EDF_INRIA/06_GAN_double_moon_empty_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vjjWPrehQq2f",
        "colab_type": "text"
      },
      "source": [
        "# Generative Adversarial Networks\n",
        "\n",
        "In this notebook, we play with the GAN described in the [course](https://mlelarge.github.io/dataflowr-slides/Slides/GAN/index.html) on a double moon dataset.\n",
        "\n",
        "Then we implement a Conditional GAN and an InfoGAN."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "11loN_KhQq2h",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# all of these libraries are used for plotting\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Plot the dataset\n",
        "def plot_data(ax, X, Y, color = 'bone'):\n",
        "    \"\"\"Scatter-plot 2D points on `ax`, coloured by `Y`, with the axes hidden.\n",
        "\n",
        "    ax: matplotlib Axes to draw on.\n",
        "    X: array indexed as X[:, 0] (x coordinates) and X[:, 1] (y coordinates).\n",
        "    Y: per-point values mapped to colours through the colormap.\n",
        "    color: name of the matplotlib colormap passed to scatter.\n",
        "    \"\"\"\n",
        "    # Hide the frame of the Axes we were given, not of pyplot's current Axes\n",
        "    # (they differ as soon as a figure has several subplots).\n",
        "    ax.axis('off')\n",
        "    ax.scatter(X[:, 0], X[:, 1], s=1, c=Y, cmap=color)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BJaazazFQq2l",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from sklearn.datasets import make_moons\n",
        "# Fixed seed so that the dataset (and every plot of it) is the same on each run.\n",
        "X, y = make_moons(n_samples=2000, noise=0.05, random_state=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tuN8exxyQq2q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Every real point gets the same value (1), so the whole dataset is drawn in one colour.\n",
        "n_samples = len(X)\n",
        "Y = np.ones(n_samples)\n",
        "\n",
        "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
        "plot_data(ax, X, Y)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a6v3mQfQQq2w",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "use_gpu = torch.cuda.is_available()\n",
        "def gpu(tensor, gpu=use_gpu):\n",
        "    \"\"\"Return `tensor.cuda()` when `gpu` is truthy, otherwise `tensor` itself.\n",
        "\n",
        "    `gpu` defaults to the value `use_gpu` had when this function was defined.\n",
        "    \"\"\"\n",
        "    return tensor.cuda() if gpu else tensor"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f_gXNJRMQq2z",
        "colab_type": "text"
      },
      "source": [
        "# A simple GAN\n",
        "\n",
        "We start with the simple GAN described in the course."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7_IuhU4OQq20",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch.nn as nn\n",
        "\n",
        "z_dim = 32        # length of the noise vector fed to the generator\n",
        "hidden_dim = 128  # width of the single hidden layer of both networks\n",
        "\n",
        "# Generator: noise vector -> 2D point.\n",
        "net_G = gpu(nn.Sequential(\n",
        "    nn.Linear(z_dim, hidden_dim),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(hidden_dim, 2),\n",
        "))\n",
        "\n",
        "# Discriminator: 2D point -> score in (0, 1) produced by the final sigmoid.\n",
        "net_D = gpu(nn.Sequential(\n",
        "    nn.Linear(2, hidden_dim),\n",
        "    nn.ReLU(),\n",
        "    nn.Linear(hidden_dim, 1),\n",
        "    nn.Sigmoid(),\n",
        "))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cV9S8Tw9Qq24",
        "colab_type": "text"
      },
      "source": [
        "Training loop as described in the course, keeping the losses for the discriminator and the generator."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BOQnI3BkQq25",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_size = 50\n",
        "lr = 1e-4\n",
        "nb_epochs = 1000\n",
        "\n",
        "optimizer_G = torch.optim.Adam(net_G.parameters(),lr=lr)\n",
        "optimizer_D = torch.optim.Adam(net_D.parameters(),lr=lr)\n",
        "\n",
        "# Per-epoch sums of the batch losses, stored as plain Python floats.\n",
        "loss_D_epoch = []\n",
        "loss_G_epoch = []\n",
        "for e in range(nb_epochs):\n",
        "    # Reshuffle the rows of X in place so the mini-batches differ between epochs.\n",
        "    np.random.shuffle(X)\n",
        "    real_samples = torch.from_numpy(X).type(torch.FloatTensor)\n",
        "    loss_G = 0\n",
        "    loss_D = 0\n",
        "    for real_batch in real_samples.split(batch_size):\n",
        "        # Draw as many fake samples as there are real ones: the last batch is\n",
        "        # shorter whenever batch_size does not divide the dataset size.\n",
        "        n_batch = real_batch.size(0)\n",
        "\n",
        "        # improving D\n",
        "        z = gpu(torch.empty(n_batch,z_dim).normal_())\n",
        "        # detach(): this step only updates D, so do not backpropagate into G.\n",
        "        fake_batch = net_G(z).detach()\n",
        "        D_scores_on_real = net_D(gpu(real_batch))\n",
        "        D_scores_on_fake = net_D(fake_batch)\n",
        "\n",
        "        loss = -torch.mean(torch.log(1-D_scores_on_fake) + torch.log(D_scores_on_real))\n",
        "        optimizer_D.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer_D.step()\n",
        "        # item(): accumulate a float, not a tensor still attached to the autograd graph.\n",
        "        loss_D += loss.item()\n",
        "\n",
        "        # improving G\n",
        "        z = gpu(torch.empty(n_batch,z_dim).normal_())\n",
        "        fake_batch = net_G(z)\n",
        "        D_scores_on_fake = net_D(fake_batch)\n",
        "\n",
        "        loss = -torch.mean(torch.log(D_scores_on_fake))\n",
        "        optimizer_G.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer_G.step()\n",
        "        loss_G += loss.item()\n",
        "\n",
        "    loss_D_epoch.append(loss_D)\n",
        "    loss_G_epoch.append(loss_G)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y-n8FyUWQq28",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# float() accepts plain numbers as well as one-element tensors (even on the GPU or\n",
        "# attached to the autograd graph), so the curves can always be drawn.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot([float(v) for v in loss_D_epoch], label='discriminator')\n",
        "ax.plot([float(v) for v in loss_G_epoch], label='generator')\n",
        "ax.set_xlabel('epoch')\n",
        "ax.set_ylabel('sum of batch losses')\n",
        "ax.legend()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lry_pUAgQq3A",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Sampling only: no_grad() skips building the autograd graph, so the result can be\n",
        "# converted to NumPy directly instead of going through the legacy `.data` attribute.\n",
        "with torch.no_grad():\n",
        "    z = gpu(torch.empty(n_samples,z_dim).normal_())\n",
        "    fake_samples = net_G(z)\n",
        "fake_data = fake_samples.cpu().numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x91JLigbQq3D",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Real points are given the value 1 and generated points 0, so the colormap separates them.\n",
        "all_data = np.vstack([X, fake_data])\n",
        "Y2 = np.hstack([np.ones(n_samples), np.zeros(n_samples)])\n",
        "\n",
        "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
        "plot_data(ax, all_data, Y2)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MAotTqPRQq3G",
        "colab_type": "text"
      },
      "source": [
        "It looks like the GAN is oscillating. Try again with lr=1e-3"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AWFz3u-IQq3H",
        "colab_type": "text"
      },
      "source": [
        "We can generate more points:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ARf29UCyQq3I",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Ten generated points for every real one.\n",
        "n_fake = 10 * n_samples\n",
        "# Sampling only: no_grad() skips building the autograd graph, so the result can be\n",
        "# converted to NumPy directly instead of going through the legacy `.data` attribute.\n",
        "with torch.no_grad():\n",
        "    z = gpu(torch.empty(n_fake,z_dim).normal_())\n",
        "    fake_samples = net_G(z)\n",
        "fake_data = fake_samples.cpu().numpy()\n",
        "\n",
        "fig, ax = plt.subplots(1, 1, facecolor='#4B6EA9')\n",
        "all_data = np.concatenate((X,fake_data),axis=0)\n",
        "# Real points are given the value 1 and generated points 0.\n",
        "Y2 = np.concatenate((np.ones(n_samples),np.zeros(n_fake)))\n",
        "plot_data(ax, all_data, Y2)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_v3oo0rwQq3L",
        "colab_type": "text"
      },
      "source": [
        "# Conditional GAN\n",
        "\n",
        "We are now implementing a [conditional GAN](https://arxiv.org/abs/1411.1784).\n",
        "\n",
        "We start by separating the two half moons in two clusters as follows:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-A1NdzdYQq3M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# This time the second return value is kept: Y is used as the class label of each point.\n",
        "# Fixed seed so that the dataset (and every plot of it) is the same on each run.\n",
        "X, Y = make_moons(n_samples=2000, noise=0.05, random_state=0)\n",
        "n_samples = X.shape[0]\n",
        "\n",
        "fig, ax = plt.subplots(1, 1, facecolor='#4B6EA9')\n",
        "plot_data(ax, X, Y)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yiyjy8sdQq3Q",
        "colab_type": "text"
      },
      "source": [
        "The task is now the following: given a white or black label, generate points in the corresponding cluster.\n",
        "\n",
        "Both the generator and the discriminator now take a one-hot encoding of the label as an additional input. The generator generates fake points corresponding to the input label. The discriminator, given a (sample, label) pair, should detect whether the pair is real or fake."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P_YhYT5cQq3R",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "z_dim = 32  # length of the noise vector given to the generator\n",
        "hidden_dim = 128  # width of the hidden layer of both networks\n",
        "label_dim = 2  # length of the one-hot label vector\n",
        "\n",
        "\n",
        "class generator(nn.Module):\n",
        "    \"\"\"Conditional generator: maps a (noise, one-hot label) pair to a 2D point.\"\"\"\n",
        "\n",
        "    def __init__(self,z_dim = z_dim, label_dim=label_dim,hidden_dim =hidden_dim):\n",
        "        super().__init__()\n",
        "        layers = [\n",
        "            nn.Linear(z_dim + label_dim, hidden_dim),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_dim, 2),\n",
        "        ]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, input, label_onehot):\n",
        "        # Noise and label are concatenated along dimension 1 before entering the network.\n",
        "        return self.net(torch.cat((input, label_onehot), dim=1))\n",
        "    \n",
        "class discriminator(nn.Module):\n",
        "    \"\"\"Conditional discriminator: maps a (2D point, one-hot label) pair to a sigmoid score.\"\"\"\n",
        "\n",
        "    def __init__(self,z_dim = z_dim, label_dim=label_dim,hidden_dim =hidden_dim):\n",
        "        super().__init__()\n",
        "        # z_dim is accepted but not used by this network.\n",
        "        layers = [\n",
        "            nn.Linear(2 + label_dim, hidden_dim),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_dim, 1),\n",
        "            nn.Sigmoid(),\n",
        "        ]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, input, label_onehot):\n",
        "        # Sample and label are concatenated along dimension 1 before entering the network.\n",
        "        return self.net(torch.cat((input, label_onehot), dim=1))\n",
        "        \n",
        "\n",
        "# Build both networks with their default sizes and pass each one through gpu().\n",
        "net_CG, net_CD = gpu(generator()), gpu(discriminator())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lkEr5NEUQq3V",
        "colab_type": "text"
      },
      "source": [
        "You need to code the training loop:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CdaV-fRBQq3W",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_size = 50\n",
        "lr = 1e-3\n",
        "nb_epochs = 1000\n",
        "\n",
        "optimizer_CG = torch.optim.Adam(net_CG.parameters(),lr=lr)\n",
        "optimizer_CD = torch.optim.Adam(net_CD.parameters(),lr=lr)\n",
        "# Per-epoch sums of the batch losses, stored as plain Python floats.\n",
        "loss_D_epoch = []\n",
        "loss_G_epoch = []\n",
        "for e in range(nb_epochs):\n",
        "    # Apply one random permutation to both X and Y so every sample keeps its label.\n",
        "    rperm = np.random.permutation(X.shape[0])\n",
        "    np.take(X,rperm,axis=0,out=X)\n",
        "    np.take(Y,rperm,axis=0,out=Y)\n",
        "    real_samples = torch.from_numpy(X).type(torch.FloatTensor)\n",
        "    real_labels = torch.from_numpy(Y).type(torch.LongTensor)\n",
        "    loss_G = 0\n",
        "    loss_D = 0\n",
        "    for real_batch, real_batch_label in zip(real_samples.split(batch_size),real_labels.split(batch_size)):\n",
        "        # improving D\n",
        "        z = gpu(torch.empty(batch_size,z_dim).normal_())\n",
        "        #\n",
        "        # your code here (it must define D_scores_on_real and D_scores_on_fake)\n",
        "        # hint: https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4\n",
        "        #\n",
        "        #\n",
        "        #\n",
        "        #\n",
        "\n",
        "        loss = -torch.mean(torch.log(1-D_scores_on_fake) + torch.log(D_scores_on_real))\n",
        "        optimizer_CD.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer_CD.step()\n",
        "        # item(): accumulate a float, not a tensor still attached to the autograd graph.\n",
        "        loss_D += loss.item()\n",
        "\n",
        "        # improving G\n",
        "        z = gpu(torch.empty(batch_size,z_dim).normal_())\n",
        "        #\n",
        "        # your code here (it must define D_scores_on_fake)\n",
        "        #\n",
        "        #\n",
        "        #\n",
        "        #\n",
        "\n",
        "        loss = -torch.mean(torch.log(D_scores_on_fake))\n",
        "        optimizer_CG.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer_CG.step()\n",
        "        loss_G += loss.item()\n",
        "\n",
        "    loss_D_epoch.append(loss_D)\n",
        "    loss_G_epoch.append(loss_G)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rfePSB-3Qq3Z",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# float() accepts plain numbers as well as one-element tensors (even on the GPU or\n",
        "# attached to the autograd graph), so the curves can always be drawn.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot([float(v) for v in loss_D_epoch], label='discriminator')\n",
        "ax.plot([float(v) for v in loss_G_epoch], label='generator')\n",
        "ax.set_xlabel('epoch')\n",
        "ax.set_ylabel('sum of batch losses')\n",
        "ax.legend()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dRCs9kKBQq3c",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "z = gpu(torch.empty(n_samples,z_dim).normal_())\n",
        "# One random class index per generated sample, then its one-hot encoding.\n",
        "label = torch.LongTensor(n_samples,1).random_() % label_dim\n",
        "label_onehot = torch.FloatTensor(n_samples, label_dim).zero_()\n",
        "label_onehot = gpu(label_onehot.scatter_(1, label, 1))\n",
        "# Sampling only: no_grad() skips building the autograd graph, so the result can be\n",
        "# converted to NumPy directly instead of going through the legacy `.data` attribute.\n",
        "with torch.no_grad():\n",
        "    fake_samples = net_CG(z, label_onehot)\n",
        "fake_data = fake_samples.cpu().numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eGZLxgwKQq3g",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Generated points use the default colormap and real points the 'spring' one;\n",
        "# within each set the colour is given by the class label.\n",
        "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
        "plot_data(ax, fake_data, label.squeeze().numpy())\n",
        "plot_data(ax, X, Y, color='spring')\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gQST3OM8Qq3m",
        "colab_type": "text"
      },
      "source": [
        "# Info GAN\n",
        "\n",
        "Implement the [algorithm](https://arxiv.org/abs/1606.03657).\n",
        "\n",
        "This time, you do not have access to the labels, but you know that there are two classes. As in the conditional GAN, the idea is to provide a random label to the generator. In contrast to the conditional GAN, however, the discriminator cannot take the label as input (since the labels are not provided to us). Instead, the discriminator predicts a label, and this predictor can be trained on fake samples!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4jXzrG6iQq3n",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch.nn.functional as F\n",
        "\n",
        "z_dim = 32  # length of the noise vector given to the generator\n",
        "hidden_dim = 128  # width of the hidden layer of both networks\n",
        "label_dim = 2  # length of the one-hot label vector given to the generator\n",
        "\n",
        "\n",
        "class Igenerator(nn.Module):\n",
        "    \"\"\"InfoGAN generator: maps a (noise, one-hot label) pair to a 2D point.\"\"\"\n",
        "\n",
        "    def __init__(self,z_dim = z_dim, label_dim=label_dim,hidden_dim =hidden_dim):\n",
        "        super().__init__()\n",
        "        layers = [\n",
        "            nn.Linear(z_dim + label_dim, hidden_dim),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(hidden_dim, 2),\n",
        "        ]\n",
        "        self.net = nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, input, label_onehot):\n",
        "        # Noise and label are concatenated along dimension 1 before entering the network.\n",
        "        return self.net(torch.cat((input, label_onehot), dim=1))\n",
        "    \n",
        "class Idiscriminator(nn.Module):\n",
        "    \"\"\"InfoGAN discriminator: one shared hidden layer followed by two sigmoid heads.\n",
        "\n",
        "    forward() returns the pair (output, est_label): the real/fake score and the\n",
        "    estimated label, each produced by its own one-unit linear layer and a sigmoid.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,z_dim = z_dim, label_dim=label_dim,hidden_dim =hidden_dim):\n",
        "        super().__init__()\n",
        "        # z_dim and label_dim are accepted but not used by this network.\n",
        "        self.fc1 = nn.Linear(2, hidden_dim)  # shared hidden layer\n",
        "        self.fc2 = nn.Linear(hidden_dim, 1)  # real/fake head\n",
        "        self.fc3 = nn.Linear(hidden_dim, 1)  # label head\n",
        "\n",
        "    def forward(self, input):\n",
        "        hidden = F.relu(self.fc1(input))\n",
        "        return torch.sigmoid(self.fc2(hidden)), torch.sigmoid(self.fc3(hidden))\n",
        "        \n",
        "\n",
        "# Build both networks with their default sizes and pass each one through gpu().\n",
        "net_IG, net_ID = gpu(Igenerator()), gpu(Idiscriminator())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3MgxSjDJQq3q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_size = 50\n",
        "lr = 1e-3\n",
        "nb_epochs = 1000\n",
        "# Not used by the template code below; presumably meant for the label-prediction\n",
        "# term you have to add.\n",
        "loss_fn = nn.BCELoss()\n",
        "\n",
        "optimizer_IG = torch.optim.Adam(net_IG.parameters(),lr=lr)\n",
        "optimizer_ID = torch.optim.Adam(net_ID.parameters(),lr=lr)\n",
        "# Per-epoch sums of the batch losses, stored as plain Python floats.\n",
        "loss_D_epoch = []\n",
        "loss_G_epoch = []\n",
        "for e in range(nb_epochs):\n",
        "\n",
        "    # Only X is shuffled: the labels Y are not used for training here.\n",
        "    rperm = np.random.permutation(X.shape[0])\n",
        "    np.take(X,rperm,axis=0,out=X)\n",
        "    real_samples = torch.from_numpy(X).type(torch.FloatTensor)\n",
        "    loss_G = 0\n",
        "    loss_D = 0\n",
        "    for real_batch in real_samples.split(batch_size):\n",
        "        # improving D\n",
        "        z = gpu(torch.empty(batch_size,z_dim).normal_())\n",
        "        #\n",
        "        # your code here (it must define D_scores_on_real and D_scores_on_fake)\n",
        "        #\n",
        "        #\n",
        "\n",
        "        lossD = -torch.mean(torch.log(1-D_scores_on_fake) + torch.log(D_scores_on_real))\n",
        "        #\n",
        "        # your code here\n",
        "        # NOTE: nothing else in this cell steps optimizer_ID or adds to loss_D,\n",
        "        # so both are expected to happen here (use lossD.item() when accumulating).\n",
        "        #\n",
        "        #\n",
        "        #\n",
        "\n",
        "        # improving G\n",
        "        z = gpu(torch.empty(batch_size,z_dim).normal_())\n",
        "        #\n",
        "        # your code here (it must define D_scores_on_fake)\n",
        "        #\n",
        "        #\n",
        "\n",
        "        lossG = -torch.mean(torch.log(D_scores_on_fake))\n",
        "        #\n",
        "        # your code here\n",
        "        #\n",
        "        #\n",
        "        optimizer_IG.zero_grad()\n",
        "        lossG.backward()\n",
        "        optimizer_IG.step()\n",
        "        # item(): accumulate a float, not a tensor still attached to the autograd graph.\n",
        "        loss_G += lossG.item()\n",
        "\n",
        "    loss_D_epoch.append(loss_D)\n",
        "    loss_G_epoch.append(loss_G)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xUV6GtxHQq3s",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# float() accepts plain numbers as well as one-element tensors (even on the GPU or\n",
        "# attached to the autograd graph), so the curves can always be drawn.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot([float(v) for v in loss_D_epoch], label='discriminator')\n",
        "ax.plot([float(v) for v in loss_G_epoch], label='generator')\n",
        "ax.set_xlabel('epoch')\n",
        "ax.set_ylabel('sum of batch losses')\n",
        "ax.legend()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y1N3rpLLQq3v",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "z = gpu(torch.empty(n_samples,z_dim).normal_())\n",
        "# One random class index per generated sample, then its one-hot encoding.\n",
        "label = torch.LongTensor(n_samples,1).random_() % label_dim\n",
        "label_onehot = torch.FloatTensor(n_samples, label_dim).zero_()\n",
        "label_onehot = gpu(label_onehot.scatter_(1, label, 1))\n",
        "# Sampling only: no_grad() skips building the autograd graph, so the result can be\n",
        "# converted to NumPy directly instead of going through the legacy `.data` attribute.\n",
        "with torch.no_grad():\n",
        "    fake_samples = net_IG(z, label_onehot)\n",
        "fake_data = fake_samples.cpu().numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eDpfBXXUQq3y",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Generated points use the default colormap and real points the 'spring' one.\n",
        "# Generated points are coloured by the label given to the generator, real points by Y.\n",
        "fig, ax = plt.subplots(nrows=1, ncols=1, facecolor='#4B6EA9')\n",
        "plot_data(ax, fake_data, label.squeeze().numpy())\n",
        "plot_data(ax, X, Y, color='spring')\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5AX4cWd5Qq34",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y_w9ZH85Qq36",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}